package drds.plus.sql_process.optimizer.pre_processor;

import drds.plus.sql_process.abstract_syntax_tree.ObjectCreateFactory;
import drds.plus.sql_process.abstract_syntax_tree.expression.NullValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.column.Column;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.*;
import drds.plus.sql_process.abstract_syntax_tree.expression.order_by.OrderBy;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Query;
import drds.plus.sql_process.optimizer.OptimizerException;
import drds.plus.sql_process.utils.DnfFilters;
import drds.plus.sql_process.utils.UniqueIdGenerator;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * 预先处理子查询
 *
 * <pre>
 * 1. 查找所有filter/select中的子查询
 * 2. 使用结算结果替换filter/select中的子查询
 * 3. 转移子查询到subQueryFilter中,执行计划按照resultFilter单独处理,避免下推
 * </pre>
 */
/**
 * Pre-processes the subqueries of a {@link Query}.
 *
 * <pre>
 * 1. Find every subquery referenced from the filters / select items of a query.
 * 2. Once a subquery's result is known, replace the subquery in the filter / select item with that result.
 * 3. Move WHERE conditions that contain subqueries into the query's subQueryFunctionFilter, so the
 *    execution plan handles them separately (as a result filter) instead of pushing them down.
 * </pre>
 */
public class SubQueryPreProcessor {

    /**
     * Separates the subquery-bearing part of the WHERE clause into the query's subQueryFunctionFilter,
     * then checks that no subquery remains anywhere else in the query.
     *
     * @param query the query to pre-process; modified in place
     * @return the same {@code query} instance
     * @throws OptimizerException if a subquery is still found outside the WHERE clause
     */
    public static Query opitmize(Query query) throws OptimizerException {
        if (query.getWhere() != null) {
            WhereFilterAndSubQueryFunctionFilter parts = split(query.getWhere());
            if (parts.subQueryFunctionFilter != null) {
                query.setWhereAndSetNeedBuild(parts.filter);
                query.setAllWhereFilter(parts.filter); // keep the "all where" filter in sync as well
                query.setSubQueryFunctionFilter(parts.subQueryFunctionFilter);
            }
        }

        List<Function> remaining = new ArrayList<Function>();
        findSubQueryFunctionFromQueryNode(query, remaining, null, true);
        if (!remaining.isEmpty()) {
            throw new OptimizerException("暂不支持非where条件的correlate subquery");
        }
        return query;
    }

    /**
     * Splits {@code filter} into a part without subqueries and a part with subqueries.
     * If the filter is not a pure conjunction (it contains an OR), the whole filter goes to the
     * subquery part because its terms cannot be separated.
     */
    private static WhereFilterAndSubQueryFunctionFilter split(Filter filter) {
        List<Function> functionList = new ArrayList<Function>();
        findSubQueryFunctionFromFilter(filter, functionList, null, true);

        WhereFilterAndSubQueryFunctionFilter result = new WhereFilterAndSubQueryFunctionFilter();
        if (functionList.isEmpty()) {
            // no subquery at all: the whole filter stays a plain WHERE
            result.filter = filter;
            return result;
        }
        if (!DnfFilters.isCnf(filter)) {
            // an OR is involved: move the whole filter
            result.subQueryFunctionFilter = filter;
            result.filter = null;
            return result;
        }

        List<Filter> plainFilterList = new ArrayList<Filter>();
        Filter subQueryFilter = null;
        for (Filter conjunct : DnfFilters.toDnfFilterList(filter)) {
            functionList.clear();
            findSubQueryFunctionFromFilter(conjunct, functionList, null, true);
            if (functionList.isEmpty()) {
                plainFilterList.add(conjunct); // stays in the original WHERE
            } else {
                subQueryFilter = DnfFilters.and(subQueryFilter, conjunct);
            }
        }
        result.subQueryFunctionFilter = subQueryFilter;
        result.filter = DnfFilters.andDnfFilterList(plainFilterList);
        return result;
    }

    /**
     * Walks every part of {@code query} that may reference a subquery (index key filter, WHERE,
     * result filter, other join-on filter, HAVING, select items, ORDER BY, GROUP BY). Already-computed
     * values are substituted, and unresolved subquery functions are appended to {@code functionList}.
     * The select item list is always re-set on the query.
     */
    private static void findSubQueryFunctionFromQueryNode(Query query, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        findSubQueryFunctionFromFilter(query.getIndexQueryKeyFilter(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        findSubQueryFunctionFromFilter(query.getWhere(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        findSubQueryFunctionFromFilter(query.getResultFilter(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        findSubQueryFunctionFromFilter(query.getOtherJoinOnFilter(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        findSubQueryFunctionFromFilter(query.getHaving(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);

        List<Item> rewrittenSelectItemList = new ArrayList<Item>();
        for (Item item : query.getSelectItemList()) {
            // the item may be a subquery (or a correlated column) whose value is already known
            Object resolved = getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem(item, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            if (resolved == null) {
                rewrittenSelectItemList.add(item);
            } else if (resolved instanceof Item) {
                rewrittenSelectItemList.add((Item) resolved);
            } else {
                rewrittenSelectItemList.add(buildConstantFilter(resolved, item.getAlias()));
            }
        }
        query.setSelectItemListAndSetNeedBuild(rewrittenSelectItemList);

        if (query.getOrderByList() != null) {
            for (OrderBy orderBy : query.getOrderByList()) {
                findSubQueryFromOrderBy(orderBy, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            }
        }
        if (query.getGroupByList() != null) {
            for (OrderBy groupBy : query.getGroupByList()) {
                findSubQueryFromOrderBy(groupBy, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            }
        }
    }

    /**
     * Recursively visits a filter tree. Logical and OR filters are descended into; every other
     * filter is treated as a {@link BooleanFilter}.
     */
    private static void findSubQueryFunctionFromFilter(Filter filter, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        if (filter == null) {
            return;
        }
        if (filter instanceof LogicalOperationFilter) {
            for (Filter child : ((LogicalOperationFilter) filter).getFilterList()) {
                findSubQueryFunctionFromFilter(child, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            }
        } else if (filter instanceof OrsFilter) {
            for (Filter child : ((OrsFilter) filter).getFilterList()) {
                findSubQueryFunctionFromFilter(child, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            }
        } else {
            findSubQueryFromBooleanFilter((BooleanFilter) filter, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        }
    }

    /**
     * Handles subqueries on the column side, the value side and in the IN value list of a boolean
     * filter. Known results are written back into the filter.
     *
     * @throws OptimizerException if an IN subquery's computed result is not a {@link List}
     */
    private static void findSubQueryFromBooleanFilter(BooleanFilter booleanFilter, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        if (booleanFilter == null) {
            return;
        }
        Object column = booleanFilter.getColumn();
        Object value = booleanFilter.getValue();

        // subquery on the left-hand side
        if (column instanceof Function && isSubQueryFunction((Function) column)) {
            Object computed = resolveWrappedSubQuery((Function) column, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            if (computed != null) {
                booleanFilter.setColumn(computed);
            }
        } else if (column instanceof Item) {
            Object resolved = getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem((Item) column, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            if (resolved != null) {
                booleanFilter.setColumn(resolved);
            }
        }

        // subquery on the right-hand side, e.g. WHERE ID = (SELECT ID FROM A)
        if (value instanceof Function && isSubQueryFunction((Function) value)) {
            Object computed = resolveWrappedSubQuery((Function) value, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            if (computed != null) {
                booleanFilter.setValue(computed);
            }
        } else if (value instanceof Item) {
            Object resolved = getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem((Item) value, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
            if (resolved != null) {
                booleanFilter.setValue(resolved);
            }
        }

        // IN subquery: the subquery function is the first element of the value list
        if (booleanFilter.getOperation() == Operation.in) {
            List<Object> valueList = booleanFilter.getValueList();
            if (valueList != null && !valueList.isEmpty()) {
                Object first = valueList.get(0);
                if (first instanceof Function && isSubQueryFunction((Function) first)) {
                    Object computed = resolveWrappedSubQuery((Function) first, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
                    if (computed != null) {
                        // the executor must supply a list here; anything else (including NullValue,
                        // which getSubQueryValue produces for a stored null) used to fail with an
                        // opaque ClassCastException
                        if (!(computed instanceof List)) {
                            throw new OptimizerException("IN subquery result must be a List but was: " + computed);
                        }
                        @SuppressWarnings("unchecked")
                        List<Object> computedList = (List<Object>) computed;
                        booleanFilter.setValueList(computedList);
                    }
                }
            }
        }
    }

    /** Substitutes a known value for an ORDER BY / GROUP BY item, wrapping constants in a constant filter. */
    private static void findSubQueryFromOrderBy(OrderBy orderBy, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        Item item = orderBy.getItem();
        if (item == null) {
            return;
        }
        Object resolved = getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem(item, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        if (resolved instanceof Item) {
            orderBy.setColumn((Item) resolved);
        } else if (resolved != null) {
            orderBy.setColumn(buildConstantFilter(resolved, item.getAlias()));
        }
    }

    /**
     * Dispatches on the kind of item. Filters are processed in place (returns null). Functions may
     * yield the computed value of a wrapped subquery. Columns may yield a correlated value from the
     * map.
     *
     * @return the replacement value for {@code item}, or null if it should stay as is
     */
    private static Object getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem(Item item, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        if (item instanceof Filter) {
            findSubQueryFunctionFromFilter((Filter) item, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        } else if (item instanceof Function) {
            return getSubQueryFunctionValueOrAddSubQueryFunction((Function) item, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        } else if (item instanceof Column) {
            return findColumnValue((Column) item, subQueryCorrelateFilterIdToValueMap);
        }
        return null;
    }

    /** Returns the value recorded for a correlated column, or null if the column is not correlated or no value is known. */
    private static Object findColumnValue(Column column, Map<Long, Object> subQueryCorrelateFilterIdToValueMap) {
        if (column.getCorrelateSubQueryItemId() > 0L) {
            if (subQueryCorrelateFilterIdToValueMap == null) {
                return null;
            }
            return subQueryCorrelateFilterIdToValueMap.get(column.getCorrelateSubQueryItemId());
        }
        return null;
    }

    /**
     * Scans the arguments of {@code function}.
     * <ul>
     * <li>A {@link Query} argument (scalar subquery) whose result is known makes this method return
     * that result, so the caller replaces the whole function with it. An unknown result is
     * registered in {@code functionList}.</li>
     * <li>An {@link Item} argument that resolves to a value (nested subquery or correlated column) is
     * substituted <em>in place</em> in the argument list. The enclosing function, e.g. {@code b + ?},
     * must be kept rather than replaced by the argument's value. Scanning continues with the
     * remaining arguments.</li>
     * </ul>
     * NOTE(review): assumes {@code getArgList()} returns the function's modifiable backing list — TODO confirm.
     */
    private static Object getSubQueryFunctionValueOrAddSubQueryFunction(Function function, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        List<?> argList = function.getArgList();
        for (int i = 0; i < argList.size(); i++) {
            Object arg = argList.get(i);
            if (arg instanceof Item) {
                Object resolved = getSubQueryFunctionValueOrAddSubQueryFunctionFromSelectItem((Item) arg, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
                if (resolved != null) {
                    @SuppressWarnings("unchecked")
                    List<Object> writableArgList = (List<Object>) argList;
                    writableArgList.set(i, resolved);
                }
            } else if (arg instanceof Query) { // scalar subquery
                Object computed = evaluateOrRegisterSubQuery(function, (Query) arg, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
                if (computed != null) {
                    return computed;
                }
            }
        }
        return null;
    }

    /**
     * Handles a subquery wrapper function (see {@link #isSubQueryFunction(Function)}) whose first
     * argument is expected to be the subquery.
     *
     * @return the computed result, or null if it is unknown (then the function has been registered)
     *         or if the wrapper does not hold a {@link Query}
     */
    private static Object resolveWrappedSubQuery(Function subQueryFunction, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        List<?> argList = subQueryFunction.getArgList();
        if (argList.isEmpty() || !(argList.get(0) instanceof Query)) {
            return null;
        }
        return evaluateOrRegisterSubQuery(subQueryFunction, (Query) argList.get(0), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
    }

    /**
     * Visits {@code query} depth-first, so its nested subqueries are collected before it, then looks
     * up its computed result.
     *
     * @return the computed result, or null after registering {@code function} in {@code functionList}
     */
    private static Object evaluateOrRegisterSubQuery(Function function, Query query, List<Function> functionList, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        findSubQueryFunctionFromQueryNode(query, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        Object computed = getSubQueryValue(query, subQueryCorrelateFilterIdToValueMap);
        if (computed != null && !(computed instanceof Query)) {
            return computed;
        }
        addSubQueryFunction(functionList, function, existCorrelatedSubQuery);
        return null;
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    /** Same as {@link #getNextSubQueryFunction(Query, boolean)} with {@code existCorrelatedSubQuery = false}. */
    public static Function getNextSubQueryFunction(Query query) {
        return getNextSubQueryFunction(query, false);
    }

    /**
     * Finds the next subquery to process. The traversal is depth-first, so the first entry of the
     * collected list is a leaf.
     *
     * @return the next subquery function, or null if there is none or {@code query} is null
     */
    public static Function getNextSubQueryFunction(Query query, boolean existCorrelatedSubQuery) {
        if (query == null) {
            return null;
        }
        return firstOrNull(findSubQueryFunctionList(query, existCorrelatedSubQuery));
    }

    /**
     * Finds all subqueries of {@code query}, including those in its subQueryFunctionFilter.
     *
     * @return a new mutable list; empty if {@code query} is null
     */
    public static List<Function> findSubQueryFunctionList(Query query, boolean existCorrelatedSubQuery) throws OptimizerException {
        if (query == null) {
            return new ArrayList<Function>();
        }
        return collectSubQueryFunctions(query, null, existCorrelatedSubQuery);
    }

    /**
     * Replaces subqueries with the results computed by the execution plan, then returns the next
     * subquery that still has to be executed.
     *
     * @param subQueryCorrelateFilterIdToValueMap computed results keyed by subquery id / correlated item id
     * @return the next unresolved subquery function, or null if none remain
     */
    public static Function subQueryAssignValueAndReturnNextSubQueryFunction(Query query, Map<Long, Object> subQueryCorrelateFilterIdToValueMap) throws OptimizerException {
        return firstOrNull(collectSubQueryFunctions(query, subQueryCorrelateFilterIdToValueMap, false));
    }

    /** Collects the subquery functions of the query node and of its subQueryFunctionFilter. */
    private static List<Function> collectSubQueryFunctions(Query query, Map<Long, Object> subQueryCorrelateFilterIdToValueMap, boolean existCorrelatedSubQuery) {
        List<Function> functionList = new ArrayList<Function>();
        findSubQueryFunctionFromQueryNode(query, functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        findSubQueryFunctionFromFilter(query.getSubQueryFunctionFilter(), functionList, subQueryCorrelateFilterIdToValueMap, existCorrelatedSubQuery);
        return functionList;
    }

    private static Function firstOrNull(List<Function> functionList) {
        return functionList.isEmpty() ? null : functionList.get(0);
    }

    /** True for the list / scalar subquery wrapper functions. */
    private static boolean isSubQueryFunction(Function function) {
        return function.getFunctionName().equals(FunctionName.subquery_list) || function.getFunctionName().equals(FunctionName.subquery_scalar);
    }

    /**
     * Looks up the computed result of {@code query} by its subquery id.
     *
     * @return null if no map is given or no result is recorded; {@link NullValue} if a null result is recorded
     */
    private static Object getSubQueryValue(Query query, Map<Long, Object> subQueryCorrelateFilterIdToValueMap) {
        if (subQueryCorrelateFilterIdToValueMap == null) {
            return null;
        }
        if (subQueryCorrelateFilterIdToValueMap.containsKey(query.getSubQueryId())) {
            Object object = subQueryCorrelateFilterIdToValueMap.get(query.getSubQueryId());
            if (object == null) {
                return NullValue.getNullValue();
            }
            return object;
        }
        return null;
    }

    /**
     * Registers a subquery function. A query without an id gets a fresh one. A query with an id is
     * added only if no registered function has the same id, since the same filter can appear in both
     * the WHERE and the result filter. When {@code existCorrelatedSubQuery} is false, the list is
     * cleared if its first entry is a correlated subquery.
     */
    private static void addSubQueryFunction(List<Function> functionList, Function function, boolean existCorrelatedSubQuery) {
        if (functionList == null) {
            return;
        }
        boolean exist = false;
        Query query = (Query) function.getArgList().get(0);
        Long subQueryFilterId = query.getSubQueryId();
        if (subQueryFilterId == null || subQueryFilterId == 0) {
            query.setSubQueryId(UniqueIdGenerator.genCorrelateSubQueryItemId());
        } else {
            for (Function registered : functionList) {
                if (((Query) registered.getArgList().get(0)).getSubQueryId().equals(subQueryFilterId)) {
                    exist = true;
                    break;
                }
            }
        }
        if (!exist) {
            functionList.add(function);
        }
        if (!existCorrelatedSubQuery) {
            clearFunctionListWhenExistCorrelatedSubQuery(functionList);
        }
    }

    /**
     * Clears {@code functionList} if its first (leaf) entry is a correlated subquery.
     *
     * @return true if the list was cleared
     */
    private static boolean clearFunctionListWhenExistCorrelatedSubQuery(List<Function> functionList) {
        if (!functionList.isEmpty()) {
            Query query = (Query) functionList.get(0).getArgList().get(0);
            if (query.isCorrelatedSubQuery()) {
                functionList.clear();
                return true;
            }
        }
        return false;
    }

    /** Wraps a constant in a boolean filter with the {@code constant} operation, named after its string form. */
    private static BooleanFilter buildConstantFilter(Object constant, String alias) {
        BooleanFilter booleanFilter = ObjectCreateFactory.createBooleanFilter();
        booleanFilter.setOperation(Operation.constant);
        booleanFilter.setColumn(constant);
        booleanFilter.setColumnName(constant.toString());
        booleanFilter.setAlias(alias);
        return booleanFilter;
    }

}
